import numpy as np
import matplotlib.pyplot as plt

# Seed NumPy's global random generator so the random draws made below
# (dataset generation and index sampling) are repeatable from run to run.
np.random.seed(2024)

def plot_UD_T():
    """
    Plot the other bounds against the recursive upper bound as t grows.

    With m = 100 and K = 100 fixed, the upper and lower recursions are iterated
    for t = 1..T (T = 5000) and their closed-form ("scaled") counterparts are
    evaluated at every t. The scaled upper bound, the recursive lower bound and
    the scaled lower bound are then drawn against the recursive upper bound.
    The figure is saved to './recursion_scaled_bounds.png' and shown.

    Takes no arguments and returns nothing. Reads the module-level settings
    ``lr_in``, ``lr_out_c``, ``gamma_tr`` and ``start``.
    """
    m = 100
    K, T = 100, 5000

    # Index 0 holds the initial value 0 of each sequence; index t holds the
    # value after step t.
    Delta_upper = [0]
    Delta_upper_final = [0]
    Delta_lower = [0]
    Delta_lower_final = [0]

    gamma_prime = (pow(1 + lr_in * gamma_tr, 2 * K) - 1) / gamma_tr
    gamma = gamma_prime
    L_prime = (pow(1 + lr_in * gamma_tr, K) - 1) / gamma_tr
    L = 0.1 + 1.1 * L_prime

    # Loop-invariant coefficients, written in the same operation order as the
    # inline formulas they replace so the floating-point results are unchanged.
    shrink = 1 - 1/m
    rate_upper = shrink * lr_out_c * gamma
    rate_lower = shrink * lr_out_c * gamma_prime
    log_rate = np.log(1 + rate_lower)
    drift_upper = 2 * lr_out_c * L
    drift_lower = 2 * lr_out_c * L_prime
    scaled_upper_denom = m * shrink * lr_out_c * gamma

    delta_upper = delta_lower = 0
    for t in range(1, T + 1):
        # One step of each recursion: multiplicative growth plus an additive term.
        delta_upper = (1 + rate_upper / t) * delta_upper + drift_upper / (t * m)
        delta_lower = (1 + rate_lower / t) * delta_lower + drift_lower / (t * m)

        # Closed-form expressions evaluated at the same t.
        delta_upper_final = drift_upper * ((1 + rate_lower) * (t ** rate_lower) - 1) / scaled_upper_denom
        delta_lower_final = drift_lower * ((0.5 * (t + 1)) ** log_rate - 1) / (m * log_rate)

        Delta_upper.append(delta_upper)
        Delta_upper_final.append(delta_upper_final)
        Delta_lower.append(delta_lower)
        Delta_lower_final.append(delta_lower_final)

    # NOTE(review): because index t is iteration t, this slice starts at
    # iteration start - 1 -- confirm that is the intended first point.
    first = start - 1
    x_values = Delta_upper[first:]

    plt.figure(figsize=(8, 6))
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.plot(x_values, Delta_upper_final[first:], color='#DC143C', label='Scaled upper bound', linewidth=3)
    plt.plot(x_values, Delta_lower[first:], color='#2ca02c', label='Recursive lower bound', linestyle='dashed', linewidth=3)
    plt.plot(x_values, Delta_lower_final[first:], color='#1E90FF', label='Scaled lower bound', linewidth=3)

    plt.xlabel(r'Recursive upper bound', labelpad=8, fontsize=22)
    plt.ylabel(r'Other bounds', labelpad=8, fontsize=22)
    plt.yticks(fontproperties='Times New Roman', size=15)
    plt.xticks(fontproperties='Times New Roman', size=15)

    plt.legend(fontsize=15)
    plt.savefig('./recursion_scaled_bounds.png', bbox_inches='tight', dpi=800)

    plt.show()


def plot_UD_K():
    """
    Plot the other bounds against the recursive upper bound as K varies.

    With m = 100 and T = 1000 fixed, and for each K in
    [50, 75, 100, 125, 150, 175, 200], the upper and lower recursions are
    iterated for T steps and their closed-form ("scaled") counterparts are
    evaluated at t = T. One point per K is collected for each series. The
    figure is saved to './recursion_scaled_bounds_K.png' and shown.

    Takes no arguments and returns nothing. Reads the module-level settings
    ``lr_in``, ``lr_out_c`` and ``gamma_tr``.
    """
    m = 100
    T = 1000
    K_values = [50, 75, 100, 125, 150, 175, 200]

    # Every series starts with a 0 entry, so each plotted curve begins at the
    # origin before the first K value.
    # NOTE(review): confirm this extra (0, 0) point is wanted in the figure.
    Delta_upper = [0]
    Delta_upper_final = [0]
    Delta_lower = [0]
    Delta_lower_final = [0]

    shrink = 1 - 1/m

    for K in K_values:
        gamma_prime = (pow(1 + lr_in * gamma_tr, 2 * K) - 1) / gamma_tr
        gamma = gamma_prime
        L_prime = (pow(1 + lr_in * gamma_tr, K) - 1) / gamma_tr
        L = 0.1 + 1.1 * L_prime

        # Coefficients that are constant over t, in the same operation order as
        # the inline formulas they replace (floating-point results unchanged).
        rate_upper = shrink * lr_out_c * gamma
        rate_lower = shrink * lr_out_c * gamma_prime
        log_rate = np.log(1 + rate_lower)
        drift_upper = 2 * lr_out_c * L
        drift_lower = 2 * lr_out_c * L_prime

        # Both recursions restart from 0 for every K.
        delta_upper = delta_lower = 0
        for t in range(1, T + 1):
            delta_upper = (1 + rate_upper / t) * delta_upper + drift_upper / (t * m)
            delta_lower = (1 + rate_lower / t) * delta_lower + drift_lower / (t * m)

        # Closed-form expressions evaluated at the final step T.
        delta_upper_final = drift_upper * ((1 + rate_lower) * (T ** rate_lower) - 1) / (m * shrink * lr_out_c * gamma)
        delta_lower_final = drift_lower * ((0.5 * (T + 1)) ** log_rate - 1) / (m * log_rate)

        Delta_upper.append(delta_upper)
        Delta_upper_final.append(delta_upper_final)
        Delta_lower.append(delta_lower)
        Delta_lower_final.append(delta_lower_final)

    plt.figure(figsize=(8, 6))
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.plot(Delta_upper, Delta_upper_final, color='#DC143C', label='Scaled upper bound', linewidth=3)
    plt.plot(Delta_upper, Delta_lower, color='#2ca02c', label='Recursive lower bound', linestyle='dashed', linewidth=3)
    plt.plot(Delta_upper, Delta_lower_final, color='#1E90FF', label='Scaled lower bound', linewidth=3)

    plt.xlabel(r'Recursive upper bound', labelpad=8, fontsize=22)
    plt.ylabel(r'Other bounds', labelpad=8, fontsize=22)
    plt.yticks(fontproperties='Times New Roman', size=15)
    plt.xticks(fontproperties='Times New Roman', size=15)

    plt.legend(fontsize=15)
    plt.savefig('./recursion_scaled_bounds_K.png', bbox_inches='tight', dpi=800)

    plt.show()


def plot_UD_m():
    """
    Plot the other bounds against the recursive upper bound as m varies.

    With K = 100 and T = 1000 fixed, and for each m in
    [100, 300, 500, 1000, 1500, 2000], the upper and lower recursions are
    iterated for T steps and their closed-form ("scaled") counterparts are
    evaluated at t = T. One point per m is collected for each series. The
    figure is saved to './recursion_scaled_bounds_m.png' and shown.

    Takes no arguments and returns nothing. Reads the module-level settings
    ``lr_in``, ``lr_out_c`` and ``gamma_tr``.
    """
    K, T = 100, 1000
    m_values = [100, 300, 500, 1000, 1500, 2000]

    # Every series starts with a 0 entry, so each plotted curve begins at the
    # origin before the first m value.
    # NOTE(review): confirm this extra (0, 0) point is wanted in the figure.
    Delta_upper = [0]
    Delta_upper_final = [0]
    Delta_lower = [0]
    Delta_lower_final = [0]

    # These constants depend only on K, lr_in and gamma_tr, not on m, so they
    # are computed once instead of once per m.
    gamma_prime = (pow(1 + lr_in * gamma_tr, 2 * K) - 1) / gamma_tr
    gamma = gamma_prime
    L_prime = (pow(1 + lr_in * gamma_tr, K) - 1) / gamma_tr
    L = 0.1 + 1.1 * L_prime
    drift_upper = 2 * lr_out_c * L
    drift_lower = 2 * lr_out_c * L_prime

    for m in m_values:
        # m-dependent coefficients, in the same operation order as the inline
        # formulas they replace (floating-point results unchanged).
        shrink = 1 - 1/m
        rate_upper = shrink * lr_out_c * gamma
        rate_lower = shrink * lr_out_c * gamma_prime
        log_rate = np.log(1 + rate_lower)

        # Both recursions restart from 0 for every m.
        delta_upper = delta_lower = 0
        for t in range(1, T + 1):
            delta_upper = (1 + rate_upper / t) * delta_upper + drift_upper / (t * m)
            delta_lower = (1 + rate_lower / t) * delta_lower + drift_lower / (t * m)

        # Closed-form expressions evaluated at the final step T.
        delta_upper_final = drift_upper * ((1 + rate_lower) * (T ** rate_lower) - 1) / (m * shrink * lr_out_c * gamma)
        delta_lower_final = drift_lower * ((0.5 * (T + 1)) ** log_rate - 1) / (m * log_rate)

        Delta_upper.append(delta_upper)
        Delta_upper_final.append(delta_upper_final)
        Delta_lower.append(delta_lower)
        Delta_lower_final.append(delta_lower_final)

    plt.figure(figsize=(8, 6))
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.plot(Delta_upper, Delta_upper_final, color='#DC143C', label='Scaled upper bound', linewidth=3)
    plt.plot(Delta_upper, Delta_lower, color='#2ca02c', label='Recursive lower bound', linestyle='dashed', linewidth=3)
    plt.plot(Delta_upper, Delta_lower_final, color='#1E90FF', label='Scaled lower bound', linewidth=3)

    plt.xlabel(r'Recursive upper bound', labelpad=8, fontsize=22)
    plt.ylabel(r'Other bounds', labelpad=8, fontsize=22)
    plt.yticks(fontproperties='Times New Roman', size=15)
    plt.xticks(fontproperties='Times New Roman', size=15)

    plt.legend(fontsize=15)
    plt.savefig('./recursion_scaled_bounds_m.png', bbox_inches='tight', dpi=800)

    plt.show()

def matrix_A(d):
    """
    Build the d x d diagonal sign matrix diag(-1, 1, ..., 1).

    :param d: Dimension of the (square) matrix
    :return: Float matrix equal to the identity except for -1 in the top-left entry
    """
    # Start from the identity and flip the sign of the first diagonal entry.
    sign_matrix = np.eye(d)
    sign_matrix[0, 0] = -1
    return sign_matrix

def generate_dataset(n, d, noise_mean=0.1, noise_std=0.1):
    """
    Sample a synthetic linear-regression dataset from NumPy's global random state.

    Each feature is a standard normal draw shifted by 1. Each label is the sum
    of the sample's features (inner product with the all-ones vector) plus
    Gaussian noise.

    :param n: Number of samples
    :param d: Number of features
    :param noise_mean: Mean of the Gaussian noise
    :param noise_std: Standard deviation of the Gaussian noise
    :return: Feature matrix of shape (n, d) and label vector of length n
    """
    # Draw order is kept (features first, then noise) so results under a fixed
    # seed are unchanged.
    features = np.random.randn(n, d) + 1
    noise = np.random.normal(noise_mean, noise_std, n)

    all_ones = np.ones(d)
    labels = features @ all_ones + noise
    return features, labels

def create_twin_set(X_base, y_base, x, y, x_tilde, y_tilde):
    """
    Build two validation sets that share a base and differ only in the last sample.

    :param X_base: Base feature matrix shared by both sets
    :param y_base: Base label vector shared by both sets
    :param x: Feature vector appended as the last row of the first set
    :param y: Label appended as the last entry of the first set
    :param x_tilde: Feature vector appended as the last row of the second set
    :param y_tilde: Label appended as the last entry of the second set
    :return: Tuple (X_val, y_val, X_tilde_val, y_tilde_val)
    """
    def _extend_base(sample, label):
        # Append one (sample, label) pair to fresh copies of the shared base.
        return np.vstack([X_base, sample]), np.append(y_base, label)

    first_X, first_y = _extend_base(x, y)
    second_X, second_y = _extend_base(x_tilde, y_tilde)
    return first_X, first_y, second_X, second_y

def plot_bilevel_optimization():
    """
    Run the twin-validation-set bilevel experiment and plot bounds vs. distance.

    Two hyperparameter vectors, ``lamb`` and ``lamb_tilde``, are updated for T
    outer steps on two validation sets that differ only in their last sample.
    From iteration ``start`` onward, the Euclidean distance between the two
    vectors is recorded together with ``t**lower_power / m`` and
    ``t**upper_power / m``. Both bound curves and their degree-1 polynomial
    fits are plotted against the recorded distance. The figure is saved to
    './optimized_thm5.5_bounds_distances.png' and shown.

    Takes no arguments and returns nothing. The sizes are fixed inside the
    function (n = m = 100, K = 100, T = 5000). The remaining settings are read
    from module level: ``d`` (feature dimension), ``lr_in`` (inner learning
    rate), ``lr_out_c`` (outer learning-rate constant; step t uses
    ``lr_out_c / t``), ``gamma_tr`` and ``start`` (first recorded iteration).
    """
    n, m = 100, 100
    K, T = 100, 5000
    A = matrix_A(d)

    X_tr, y_tr = generate_dataset(n, d)
    X_val_base, y_val_base = generate_dataset(m-1, d)
    # The twin sets share the m-1 base samples; the last sample is (v_1, 1)
    # in one set and (-v_1, 1) in the other.
    v_1 = np.array([1] + [0] * (d - 1))
    X_val, y_val, X_tilde_val, y_tilde_val = create_twin_set(X_val_base, y_val_base, v_1, 1, -v_1, 1)

    lamb = np.zeros(d)
    lamb_tilde = np.zeros(d)

    gamma_prime = (pow(1 + lr_in * gamma_tr, 2 * K) - 1) / gamma_tr
    gamma = gamma_prime
    lower_power = np.log(1 + (1 - 1 / m) * lr_out_c * gamma_prime)
    upper_power = (1 - 1 / m) * lr_out_c * gamma

    # jacobian = -lr_in * sum_{k=0}^{K-1} (I - lr_in * A)^k. It depends only on
    # constants, so it is built once here instead of on every outer step.
    B = np.eye(d) - lr_in * A
    matrix_sum = np.zeros((d, d))
    B_power = np.eye(d)
    for _ in range(K):
        matrix_sum += B_power
        B_power = B_power @ B
    jacobian = -lr_in * matrix_sum

    lamb_distances = []
    lower_bounds = []
    upper_bounds = []

    for t in range(1, T + 1):
        # The inner variable restarts from zero on every outer step.
        theta = np.zeros(d)
        lr_out = lr_out_c / t

        # K inner steps, each on one randomly chosen training sample.
        for _ in range(K):
            idx_tr = np.random.randint(n)
            xj, yj = X_tr[idx_tr], y_tr[idx_tr]
            gradient_tr = A @ theta + lamb - yj * xj
            theta = theta - lr_in * gradient_tr

        # The same validation index is used for both twin sets, so the two
        # updates differ in data only when the last sample is drawn.
        idx_val = np.random.randint(m)
        xi, yi = X_val[idx_val], y_val[idx_val]
        xi_tilde, yi_tilde = X_tilde_val[idx_val], y_tilde_val[idx_val]

        gradient_val = theta + jacobian @ (A @ theta + lamb - yi * xi)
        lamb = lamb - lr_out * gradient_val

        # NOTE(review): theta above was computed with `lamb` only and is reused
        # for the lamb_tilde update -- confirm that sharing it is intended.
        gradient_tilde_val = theta + jacobian @ (A @ theta + lamb_tilde - yi_tilde * xi_tilde)
        lamb_tilde = lamb_tilde - lr_out * gradient_tilde_val

        lamb_distance = np.linalg.norm(lamb - lamb_tilde)

        lower_bound = pow(t, lower_power) / m
        upper_bound = pow(t, upper_power) / m

        if t >= start:
            lamb_distances.append(lamb_distance)
            lower_bounds.append(lower_bound)
            upper_bounds.append(upper_bound)

    upper_bounds = np.array(upper_bounds)
    lower_bounds = np.array(lower_bounds)

    plt.figure(figsize=(8, 6))
    plt.rcParams['font.family'] = ['Times New Roman']
    plt.plot(lamb_distances, upper_bounds, label='Upper bound', color='#DC143C', linewidth=3)
    plt.plot(lamb_distances, lower_bounds, label='Lower bound', color='#1E90FF', linewidth=3)
    # Fit linear trends
    upper_trend = np.polyfit(lamb_distances, upper_bounds, 1)
    lower_trend = np.polyfit(lamb_distances, lower_bounds, 1)
    # Calculate trend lines
    upper_trend_line = np.polyval(upper_trend, lamb_distances)
    lower_trend_line = np.polyval(lower_trend, lamb_distances)
    # Plot trend lines
    plt.plot(lamb_distances, upper_trend_line, label='Linear fitting of the upper bound', linestyle='--', color='#CD5C5C', linewidth=2)
    plt.plot(lamb_distances, lower_trend_line, label='Linear fitting of the lower bound', linestyle='--', color='#4e79a7', linewidth=2)

    plt.xlabel('Hyperparameter distance', labelpad=8, fontsize=22)
    plt.ylabel('Upper and lower bounds', labelpad=8, fontsize=22)
    plt.yticks(fontproperties='Times New Roman', size=15)
    plt.xticks(fontproperties='Times New Roman', size=15)
    plt.legend(fontsize=15)

    plt.savefig('./optimized_thm5.5_bounds_distances.png', bbox_inches='tight', dpi=800)
    plt.show()


# Training settings
# These are module-level globals: the plotting functions above read them
# directly by name, so they must be assigned before any of the calls below.
# Feature dimension used by plot_bilevel_optimization.
d = 2
# Inner-loop learning rate.
lr_in = 0.01
# Outer learning-rate constant; outer step t uses lr_out_c / t.
lr_out_c = 0.01
# Constant entering the gamma_prime and L_prime formulas.
# NOTE(review): presumably a smoothness constant of the training loss -- confirm.
gamma_tr = 1
# First outer iteration whose values are recorded in plot_bilevel_optimization;
# plot_UD_T slices its series from index start - 1.
start = 1001


# Generate the figures in sequence; each call saves a PNG and then calls plt.show().
plot_bilevel_optimization()
plot_UD_T()
plot_UD_K()
plot_UD_m()



